/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the
 * "License"). Please refer to the License for details. You may not use this
 * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
 * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
 * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
 * for the full text of the License.
 */

#ifndef EXAMPLES_NORMALIZATION_WELFORDFINALIZE_CUSTOM_H
#define EXAMPLES_NORMALIZATION_WELFORDFINALIZE_CUSTOM_H
#include "kernel_operator.h"

// Element count of each output buffer (mean and variance) that
// KernelWelfordFinalize allocates and copies back to global memory.
constexpr uint32_t OUT_SIZE{8U};

namespace MyCustomKernel {
/**
 * Tiling parameters handed by value to KernelWelfordFinalize::Init.
 *
 * All fields except tmpLocalSize are forwarded unchanged into
 * AscendC::WelfordFinalizePara. The kernel additionally derives a reduce
 * length from them:
 *   tailLength == 0 : rnLength * abLength
 *   otherwise       : head * headLength + tail * tailLength
 */
struct VecTiling {
    // Multiplied by abLength to obtain the reduce length when tailLength == 0.
    uint32_t rnLength;
    // Element count of each input buffer (mean, variance, counts).
    uint32_t abLength;
    // With headLength: first term of the reduce length when tailLength != 0.
    uint32_t head;
    uint32_t headLength;
    // With tailLength: second term of the reduce length when tailLength != 0.
    uint32_t tail;
    // A value of 0 selects the rnLength * abLength reduce-length formula.
    uint32_t tailLength;
    // Size passed to TPipe::InitBuffer for the shared temporary buffer
    // (presumably in bytes -- TODO confirm against the host-side tiling code).
    uint32_t tmpLocalSize;
};

/**
 * @brief Sample kernel that runs AscendC::WelfordFinalize once on a single tile.
 *
 * The kernel copies abLength elements of mean, variance and counts from global
 * memory, invokes WelfordFinalize, and writes OUT_SIZE elements of the resulting
 * mean and variance back to global memory.
 *
 * @tparam dataType       Element type of the mean/variance buffers.
 * @tparam isReuseSource  Forwarded as the template argument of WelfordFinalize.
 * @tparam isCounts       When true, the overload taking the counts tensor is used.
 * @tparam sharedTmp      When true, a temporary buffer of tilingData.tmpLocalSize is
 *                        allocated here and passed to WelfordFinalize; when false the
 *                        overload without a temporary-buffer argument is used and no
 *                        buffer is allocated.
 */
template <typename dataType, bool isReuseSource = false, bool isCounts = false, bool sharedTmp = false>
class KernelWelfordFinalize {
public:
    __aicore__ inline KernelWelfordFinalize()
    {}

    /**
     * @brief Caches the tiling values, binds the global buffers and allocates local buffers.
     *
     * @param inputMean_gm       Global address of the input mean (abLength elements).
     * @param inputVariance_gm   Global address of the input variance (abLength elements).
     * @param counts_gm          Global address of the int32 counts (abLength elements).
     *                           Counts are copied in regardless of isCounts, so this must
     *                           always reference readable memory.
     * @param outputMean_gm      Global address of the output mean (OUT_SIZE elements).
     * @param outputVariance_gm  Global address of the output variance (OUT_SIZE elements).
     * @param tilingData         Tiling parameters; see VecTiling.
     */
    __aicore__ inline void Init(GM_ADDR inputMean_gm, GM_ADDR inputVariance_gm, GM_ADDR counts_gm, GM_ADDR outputMean_gm,
        GM_ADDR outputVariance_gm, VecTiling tilingData)
    {
        this->rnLength = tilingData.rnLength;
        this->abLength = tilingData.abLength;
        this->head = tilingData.head;
        this->headLength = tilingData.headLength;
        this->tail = tilingData.tail;
        this->tailLength = tilingData.tailLength;
        this->stackBufferSize = tilingData.tmpLocalSize;

        // Reduce length: plain product when there is no tail, otherwise the
        // sum of the head and tail contributions.
        if (tailLength == 0) {
            this->rLength = rnLength * abLength;
        } else {
            this->rLength = head * headLength + tail * tailLength;
        }
        // Reciprocals handed to WelfordFinalize through WelfordFinalizePara.
        this->abRec = 1.0f / static_cast<float>(abLength);
        this->rRec = 1.0f / static_cast<float>(rLength);
        this->outLength = OUT_SIZE;

        inputMean_global.SetGlobalBuffer(reinterpret_cast<__gm__ dataType *>(inputMean_gm), abLength);
        inputVariance_global.SetGlobalBuffer(reinterpret_cast<__gm__ dataType *>(inputVariance_gm), abLength);
        inputcounts_global.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(counts_gm), abLength);
        outputMean_global.SetGlobalBuffer(reinterpret_cast<__gm__ dataType *>(outputMean_gm), outLength);
        outputVariance_global.SetGlobalBuffer(reinterpret_cast<__gm__ dataType *>(outputVariance_gm), outLength);

        // NOTE(review): DataCopy is later issued with abLength / outLength elements;
        // confirm these satisfy the block-alignment requirement of DataCopy for dataType.
        pipe.InitBuffer(inQueueMean, 1, abLength * sizeof(dataType));
        pipe.InitBuffer(inQueueVariance, 1, abLength * sizeof(dataType));
        pipe.InitBuffer(inQueueCounts, 1, abLength * sizeof(int32_t));
        pipe.InitBuffer(outQueueMean, 1, outLength * sizeof(dataType));
        pipe.InitBuffer(outQueueVariance, 1, outLength * sizeof(dataType));

        // Allocate the temporary buffer once, together with the other buffers and
        // only when it is actually handed to WelfordFinalize. It is initialised
        // after the queues so the allocation order matches the previous layout.
        if constexpr (sharedTmp) {
            pipe.InitBuffer(sharedTmpBuffer, stackBufferSize);
        }
    }

    /**
     * @brief Executes the copy-in, compute and copy-out stages in order.
     *        Init must have been called beforehand.
     */
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    /// Copies mean, variance and counts from global memory into the input queues.
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<dataType> inputMeanLocal = inQueueMean.AllocTensor<dataType>();
        AscendC::LocalTensor<dataType> inputVarianceLocal = inQueueVariance.AllocTensor<dataType>();
        AscendC::LocalTensor<int32_t> inputCountsLocal = inQueueCounts.AllocTensor<int32_t>();

        AscendC::DataCopy(inputMeanLocal, inputMean_global, abLength);
        AscendC::DataCopy(inputVarianceLocal, inputVariance_global, abLength);
        AscendC::DataCopy(inputCountsLocal, inputcounts_global, abLength);

        inQueueMean.EnQue(inputMeanLocal);
        inQueueVariance.EnQue(inputVarianceLocal);
        inQueueCounts.EnQue(inputCountsLocal);
    }

    /// Dequeues the inputs, calls the WelfordFinalize overload selected by
    /// isCounts / sharedTmp, and enqueues the mean and variance results.
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<dataType> inputMeanLocal = inQueueMean.DeQue<dataType>();
        AscendC::LocalTensor<dataType> inputVarianceLocal = inQueueVariance.DeQue<dataType>();
        AscendC::LocalTensor<int32_t> inputCountsLocal = inQueueCounts.DeQue<int32_t>();

        AscendC::LocalTensor<dataType> meanLocal = outQueueMean.AllocTensor<dataType>();
        AscendC::LocalTensor<dataType> varianceLocal = outQueueVariance.AllocTensor<dataType>();

        struct AscendC::WelfordFinalizePara para = {rnLength, abLength, head, headLength, tail, tailLength, abRec, rRec};

        if constexpr (sharedTmp) {
            // Temporary buffer was allocated in Init.
            AscendC::LocalTensor<uint8_t> tmpLocalTensor = sharedTmpBuffer.Get<uint8_t>();
            if constexpr (isCounts) {
                AscendC::WelfordFinalize<isReuseSource>(meanLocal, varianceLocal, inputMeanLocal, inputVarianceLocal,
                    inputCountsLocal, tmpLocalTensor, para);
            } else {
                AscendC::WelfordFinalize<isReuseSource>(meanLocal, varianceLocal, inputMeanLocal, inputVarianceLocal,
                    tmpLocalTensor, para);
            }
        } else {
            if constexpr (isCounts) {
                AscendC::WelfordFinalize<isReuseSource>(meanLocal, varianceLocal, inputMeanLocal, inputVarianceLocal,
                    inputCountsLocal, para);
            } else {
                AscendC::WelfordFinalize<isReuseSource>(meanLocal, varianceLocal, inputMeanLocal, inputVarianceLocal,
                    para);
            }
        }

        outQueueMean.EnQue<dataType>(meanLocal);
        outQueueVariance.EnQue<dataType>(varianceLocal);

        inQueueMean.FreeTensor(inputMeanLocal);
        inQueueVariance.FreeTensor(inputVarianceLocal);
        inQueueCounts.FreeTensor(inputCountsLocal);
    }

    /// Copies outLength elements of mean and variance from the output queues to global memory.
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<dataType> meanLocal = outQueueMean.DeQue<dataType>();
        AscendC::LocalTensor<dataType> varianceLocal = outQueueVariance.DeQue<dataType>();

        AscendC::DataCopy(outputMean_global, meanLocal, outLength);
        AscendC::DataCopy(outputVariance_global, varianceLocal, outLength);

        outQueueMean.FreeTensor(meanLocal);
        outQueueVariance.FreeTensor(varianceLocal);
    }

private:
    AscendC::GlobalTensor<dataType> inputMean_global;
    AscendC::GlobalTensor<dataType> inputVariance_global;
    AscendC::GlobalTensor<int32_t> inputcounts_global;
    AscendC::GlobalTensor<dataType> outputMean_global;
    AscendC::GlobalTensor<dataType> outputVariance_global;

    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueMean;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueVariance;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueCounts;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueueMean;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueueVariance;
    // Only initialised (in Init) when sharedTmp is true.
    AscendC::TBuf<AscendC::TPosition::VECCALC> sharedTmpBuffer;

    // Tiling values cached by Init; default-initialised so that no member is
    // ever read in an indeterminate state.
    uint32_t rnLength = 0;
    uint32_t abLength = 0;
    uint32_t rLength = 0;
    uint32_t head = 0;
    uint32_t headLength = 0;
    uint32_t tail = 0;
    uint32_t tailLength = 0;
    uint32_t outLength = 0;
    float abRec = 0.0f;
    float rRec = 0.0f;
    // Not read or written by any method of this class.
    bool inplace = false;

    uint32_t stackBufferSize = 0;
};

} // namespace MyCustomKernel

#endif // EXAMPLES_NORMALIZATION_WELFORDFINALIZE_CUSTOM_H
